{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#from build_predictions import * \n",
    "#from GaussRank import GaussRankMap\n",
    "import pandas as pd \n",
    "#from create_parquet import *\n",
    "#from data import *\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR='/rapids/notebooks/srabhi/champs-2019/input/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Get the three frames needed for building train / validation oupling stacking "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Nodes frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_frame = pd.read_csv(DATA_DIR+'parquet/baseline_node_frame.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>symbol</th>\n",
       "      <th>acceptor</th>\n",
       "      <th>donor</th>\n",
       "      <th>aromatic</th>\n",
       "      <th>hybridization</th>\n",
       "      <th>num_h</th>\n",
       "      <th>atomic</th>\n",
       "      <th>num_nodes</th>\n",
       "      <th>node_dim</th>\n",
       "      <th>molecule_name</th>\n",
       "      <th>atom_index</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>21</td>\n",
       "      <td>7</td>\n",
       "      <td>dsgdb9nsd_101594</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>21</td>\n",
       "      <td>7</td>\n",
       "      <td>dsgdb9nsd_101594</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   symbol  acceptor  donor  aromatic  hybridization  num_h  atomic  num_nodes  \\\n",
       "0     2.0       0.0    0.0       0.0            3.0    3.0     6.0         21   \n",
       "1     3.0       0.0    1.0       0.0            2.0    1.0     7.0         21   \n",
       "\n",
       "   node_dim     molecule_name  atom_index  \n",
       "0         7  dsgdb9nsd_101594           0  \n",
       "1         7  dsgdb9nsd_101594           1  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node_frame.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Molecule node frame "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>molecule_name</th>\n",
       "      <th>num_nodes</th>\n",
       "      <th>node_dim</th>\n",
       "      <th>node_0</th>\n",
       "      <th>node_1</th>\n",
       "      <th>node_2</th>\n",
       "      <th>node_3</th>\n",
       "      <th>node_4</th>\n",
       "      <th>node_5</th>\n",
       "      <th>node_6</th>\n",
       "      <th>...</th>\n",
       "      <th>node_214</th>\n",
       "      <th>node_215</th>\n",
       "      <th>node_216</th>\n",
       "      <th>node_217</th>\n",
       "      <th>node_218</th>\n",
       "      <th>node_219</th>\n",
       "      <th>node_220</th>\n",
       "      <th>node_221</th>\n",
       "      <th>node_222</th>\n",
       "      <th>node_223</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>dsgdb9nsd_000001</td>\n",
       "      <td>5</td>\n",
       "      <td>7</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>dsgdb9nsd_000002</td>\n",
       "      <td>4</td>\n",
       "      <td>7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 227 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      molecule_name  num_nodes  node_dim  node_0  node_1  node_2  node_3  \\\n",
       "0  dsgdb9nsd_000001          5         7     2.0     0.0     0.0     0.0   \n",
       "1  dsgdb9nsd_000002          4         7     3.0     0.0     1.0     0.0   \n",
       "\n",
       "   node_4  node_5  node_6  ...  node_214  node_215  node_216  node_217  \\\n",
       "0     3.0     4.0     6.0  ...       0.0       0.0       0.0       0.0   \n",
       "1     3.0     3.0     7.0  ...       0.0       0.0       0.0       0.0   \n",
       "\n",
       "   node_218  node_219  node_220  node_221  node_222  node_223  \n",
       "0       0.0       0.0       0.0       0.0       0.0       0.0  \n",
       "1       0.0       0.0       0.0       0.0       0.0       0.0  \n",
       "\n",
       "[2 rows x 227 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node_cols = ['symbol','acceptor', 'donor', 'aromatic',  'hybridization', 'num_h', 'atomic']\n",
    "shared_cols = ['molecule_name', 'num_nodes', 'node_dim']\n",
    "tmp = node_frame.groupby('molecule_name').apply(lambda x: x[node_cols].values.reshape(-1))\n",
    "molecule_node = pd.DataFrame(tmp.values.tolist())\n",
    "#pad node max 29 to 32 \n",
    "pad_cols = 21\n",
    "d = dict.fromkeys([str(i) for i in range(molecule_node.shape[1], molecule_node.shape[1]+pad_cols)], 0.0)\n",
    "molecule_node = molecule_node.assign(**d).fillna(0.0)\n",
    "molecule_node['molecule_name'] = tmp.index\n",
    "molecule_node = molecule_node.merge(node_frame[shared_cols].drop_duplicates(), on='molecule_name', how='left')\n",
    "cols = molecule_node.columns.tolist()\n",
    "new_cols = cols[-3:] + cols[:-3]\n",
    "molecule_node = molecule_node[new_cols]\n",
    "molecule_node.columns = ['molecule_name', 'num_nodes', 'node_dim'] + ['node_%s'%i for i in range(NODE_MAX*7)]\n",
    "molecule_node.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Edges frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "edge_frame = pd.read_csv(DATA_DIR+'parquet/baseline_edge_frame.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>atom_index_0</th>\n",
       "      <th>atom_index_1</th>\n",
       "      <th>edge_type</th>\n",
       "      <th>distance</th>\n",
       "      <th>angle</th>\n",
       "      <th>molecule_name</th>\n",
       "      <th>num_edge</th>\n",
       "      <th>edge_dim</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.450358</td>\n",
       "      <td>-0.498035</td>\n",
       "      <td>dsgdb9nsd_101594</td>\n",
       "      <td>420</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.454559</td>\n",
       "      <td>-0.499508</td>\n",
       "      <td>dsgdb9nsd_101594</td>\n",
       "      <td>420</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   atom_index_0  atom_index_1  edge_type  distance     angle  \\\n",
       "0           0.0           1.0        1.0  1.450358 -0.498035   \n",
       "1           0.0           2.0        0.0  2.454559 -0.499508   \n",
       "\n",
       "      molecule_name  num_edge  edge_dim  \n",
       "0  dsgdb9nsd_101594       420         5  \n",
       "1  dsgdb9nsd_101594       420         5  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_frame.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Molecule edge frame "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>molecule_name</th>\n",
       "      <th>num_edge</th>\n",
       "      <th>edge_dim</th>\n",
       "      <th>edge_0</th>\n",
       "      <th>edge_1</th>\n",
       "      <th>edge_2</th>\n",
       "      <th>edge_3</th>\n",
       "      <th>edge_4</th>\n",
       "      <th>edge_5</th>\n",
       "      <th>edge_6</th>\n",
       "      <th>...</th>\n",
       "      <th>edge_4070</th>\n",
       "      <th>edge_4071</th>\n",
       "      <th>edge_4072</th>\n",
       "      <th>edge_4073</th>\n",
       "      <th>edge_4074</th>\n",
       "      <th>edge_4075</th>\n",
       "      <th>edge_4076</th>\n",
       "      <th>edge_4077</th>\n",
       "      <th>edge_4078</th>\n",
       "      <th>edge_4079</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>dsgdb9nsd_000001</td>\n",
       "      <td>20</td>\n",
       "      <td>5</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.091953</td>\n",
       "      <td>-0.901529</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>dsgdb9nsd_000002</td>\n",
       "      <td>12</td>\n",
       "      <td>5</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.017190</td>\n",
       "      <td>0.292853</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 4083 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      molecule_name  num_edge  edge_dim  edge_0  edge_1  edge_2    edge_3  \\\n",
       "0  dsgdb9nsd_000001        20         5     0.0     1.0     1.0  1.091953   \n",
       "1  dsgdb9nsd_000002        12         5     0.0     1.0     1.0  1.017190   \n",
       "\n",
       "     edge_4  edge_5  edge_6  ...  edge_4070  edge_4071  edge_4072  edge_4073  \\\n",
       "0 -0.901529     0.0     2.0  ...        0.0        0.0        0.0        0.0   \n",
       "1  0.292853     0.0     2.0  ...        0.0        0.0        0.0        0.0   \n",
       "\n",
       "   edge_4074  edge_4075  edge_4076  edge_4077  edge_4078  edge_4079  \n",
       "0        0.0        0.0        0.0        0.0        0.0        0.0  \n",
       "1        0.0        0.0        0.0        0.0        0.0        0.0  \n",
       "\n",
       "[2 rows x 4083 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edge_cols = ['atom_index_0', 'atom_index_1', 'edge_type', 'distance', 'angle' ]\n",
    "shared_cols = ['molecule_name', 'num_edge', 'edge_dim']\n",
    "tmp = edge_frame.groupby('molecule_name').apply(lambda x: x[edge_cols].values.reshape(-1))\n",
    "molecule_edge = pd.DataFrame(tmp.values.tolist())\n",
    "#pad edge_max 812 to 816\n",
    "pad_cols = 4 * 5\n",
    "d = dict.fromkeys([str(i) for i in range(molecule_edge.shape[1], molecule_edge.shape[1]+pad_cols)], 0.0)\n",
    "molecule_edge = molecule_edge.assign(**d).fillna(0.0)\n",
    "molecule_edge['molecule_name'] = tmp.index\n",
    "molecule_edge = molecule_edge.merge(edge_frame[shared_cols].drop_duplicates(), on='molecule_name', how='left')\n",
    "cols = molecule_edge.columns.tolist()\n",
    "new_cols = cols[-3:] + cols[:-3]\n",
    "molecule_edge = molecule_edge[new_cols]\n",
    "molecule_edge.columns = ['molecule_name', 'num_edge', 'edge_dim']+ ['edge_%s'%i for i in range(EDGE_MAX*5)]\n",
    "molecule_edge.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Coupling frame**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "coupling_frame = pd.read_csv(DATA_DIR+'parquet/baseline_coupling_frame.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>atom_index_0</th>\n",
       "      <th>atom_index_1</th>\n",
       "      <th>coupling_type</th>\n",
       "      <th>scalar_coupling</th>\n",
       "      <th>fc</th>\n",
       "      <th>sd</th>\n",
       "      <th>pso</th>\n",
       "      <th>dso</th>\n",
       "      <th>id</th>\n",
       "      <th>num_coupling</th>\n",
       "      <th>coupling_dim</th>\n",
       "      <th>molecule_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>9.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>84.50560</td>\n",
       "      <td>83.117900</td>\n",
       "      <td>0.133123</td>\n",
       "      <td>0.548422</td>\n",
       "      <td>0.706185</td>\n",
       "      <td>3528934.0</td>\n",
       "      <td>64</td>\n",
       "      <td>9</td>\n",
       "      <td>dsgdb9nsd_101594</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>9.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>-0.42038</td>\n",
       "      <td>-0.454646</td>\n",
       "      <td>-0.012691</td>\n",
       "      <td>0.098126</td>\n",
       "      <td>-0.051169</td>\n",
       "      <td>3528935.0</td>\n",
       "      <td>64</td>\n",
       "      <td>9</td>\n",
       "      <td>dsgdb9nsd_101594</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   atom_index_0  atom_index_1  coupling_type  scalar_coupling         fc  \\\n",
       "0           9.0           0.0            0.0         84.50560  83.117900   \n",
       "1           9.0           1.0            4.0         -0.42038  -0.454646   \n",
       "\n",
       "         sd       pso       dso         id  num_coupling  coupling_dim  \\\n",
       "0  0.133123  0.548422  0.706185  3528934.0            64             9   \n",
       "1 -0.012691  0.098126 -0.051169  3528935.0            64             9   \n",
       "\n",
       "      molecule_name  \n",
       "0  dsgdb9nsd_101594  \n",
       "1  dsgdb9nsd_101594  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coupling_frame.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build Train / validation gaussrank value for each fold "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import time\n",
    "def save_cv_data(fold, coupling_frame):\n",
    "    print('fold: %s' %fold)\n",
    "    split_train = 'train_split_by_mol_hash.%s.npy'%fold\n",
    "    split_valid = 'valid_split_by_mol_hash.%s.npy'%fold\n",
    "    id_train_ = np.load(DATA_DIR + '/split/%s'%split_train,allow_pickle=True)\n",
    "    id_valid_ = np.load(DATA_DIR + '/split/%s'%split_valid,allow_pickle=True)\n",
    "    csv = 'test'\n",
    "    df = pd.read_csv(DATA_DIR + '/csv/%s.csv'%csv)\n",
    "    id_test_ = df.molecule_name.unique()\n",
    "    \n",
    "    train = coupling_frame[coupling_frame.molecule_name.isin(id_train_)]\n",
    "    validation = coupling_frame[coupling_frame.molecule_name.isin(id_valid_)]\n",
    "    test = coupling_frame[coupling_frame.molecule_name.isin(id_test_)]\n",
    "    \n",
    "    # Get GaussRank of coupling values \n",
    "    t0 = time()\n",
    "    grm = GaussRankMap()\n",
    "    df_train = train[['coupling_type', 'scalar_coupling']]\n",
    "    df_valid = validation[['coupling_type', 'scalar_coupling']]\n",
    "    df_train.columns = ['type', 'scalar_coupling_constant']\n",
    "    df_valid.columns = ['type', 'scalar_coupling_constant']\n",
    "    # Reverse type mapping \n",
    "    df_train.type = df_train.type.map(REVERSE_COUPLING_TYPE)\n",
    "    df_valid.type = df_valid.type.map(REVERSE_COUPLING_TYPE)\n",
    "    #fit grm \n",
    "    transformed_training = grm.fit_training(df_train, reset=True)\n",
    "    transformed_validation = grm.convert_df(df_valid, from_coupling=True)\n",
    "    validation['gaussrank_coupling'] =  transformed_validation\n",
    "    train['gaussrank_coupling'] = transformed_training\n",
    "    print('Getting gaussrank transformation for train/validation data took %s seconds' %(time()-t0))\n",
    "    print(grm.coupling_order)\n",
    "    test['gaussrank_coupling'] = 0 \n",
    "\n",
    "    general_coupling_frame = pd.concat([train, validation, test])\n",
    "\n",
    "    # Build molecule coupling frame for fold \n",
    "    coupling_cols = ['atom_index_0', 'atom_index_1', 'coupling_type', 'scalar_coupling', 'gaussrank_coupling',  'fc', 'sd', 'pso', 'dso', 'id']\n",
    "    shared_cols = ['molecule_name', 'num_coupling', 'coupling_dim']\n",
    "\n",
    "    tmp = general_coupling_frame.groupby('molecule_name').apply(lambda x: x[coupling_cols].values.reshape(-1))\n",
    "    molecule_coupling = pd.DataFrame(tmp.values.tolist())\n",
    "    # pad coupling_max from 135 to 136\n",
    "    pad_cols = 10\n",
    "    d = dict.fromkeys([str(i) for i in range(molecule_coupling.shape[1], molecule_coupling.shape[1]+pad_cols)], 0.0)\n",
    "    molecule_coupling = molecule_coupling.assign(**d).fillna(0.0)\n",
    "    molecule_coupling['molecule_name'] = tmp.index\n",
    "    molecule_coupling = molecule_coupling.merge(general_coupling_frame[shared_cols].drop_duplicates(), on='molecule_name', how='left')\n",
    "\n",
    "    cols = molecule_coupling.columns.tolist()\n",
    "    new_cols = cols[-3:] + cols[:-3]\n",
    "    molecule_coupling = molecule_coupling[new_cols]\n",
    "    molecule_coupling.columns = ['molecule_name', 'num_coupling', 'coupling_dim'] + ['coupling_%s'%i for i in range(COUPLING_MAX*10)]\n",
    "\n",
    "    print(molecule_coupling.shape, molecule_edge.shape, molecule_node.shape)\n",
    "\n",
    "    node_edge_frame = pd.merge(molecule_node, molecule_edge, on='molecule_name', how='left')\n",
    "    general_stack_frame =  pd.merge(node_edge_frame, molecule_coupling, on='molecule_name', how='left')\n",
    "\n",
    "    train_frame = general_stack_frame[general_stack_frame.molecule_name.isin(id_train_)]\n",
    "    validation_frame = general_stack_frame[general_stack_frame.molecule_name.isin(id_valid_)]\n",
    "    test_frame = general_stack_frame[general_stack_frame.molecule_name.isin(id_test_)]\n",
    "\n",
    "    validation_frame.to_parquet(DATA_DIR+'/parquet/fold_%s'%fold+'/validation.parquet')\n",
    "    train_frame.to_parquet(DATA_DIR +'/parquet/fold_%s'%fold+ '/train.parquet')\n",
    "    # save mapping\n",
    "    for i, (type_, frame) in enumerate(zip(grm.coupling_order, grm.training_maps)): \n",
    "        frame.to_csv(DATA_DIR +'/parquet/fold_%s'%fold+'/mapping_type_%s_order_%s.csv'%(type_, i), index=False)\n",
    "    return test_frame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fold: 1\n",
      "Getting gaussrank transformation for train/validation data took 7.560222148895264 seconds\n",
      "['1JHC', '2JHC', '3JHC', '2JHH', '3JHH', '1JHN', '3JHN', '2JHN']\n",
      "(130775, 1363) (130775, 4083) (130775, 227)\n"
     ]
    }
   ],
   "source": [
    "test_frame = save_cv_data(1, coupling_frame) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "folds = 4\n",
    "for fold in range(1, folds):\n",
    "    print(coupling_frame.shape)\n",
    "    test_frame = save_cv_data(fold, coupling_frame)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_frame.to_parquet(DATA_DIR +'/parquet/test.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
